import os
import matplotlib.pyplot as plt
import numpy as np


def dict2value(dict_str):
    """Parse the number from a ``"<key> <value> ..."`` log fragment.

    The fragment is split on whitespace and the second token is converted
    with ``float``.

    Args:
        dict_str: text such as ``" pseudo_acc: 0.95"``.

    Returns:
        The second whitespace-separated token as a float.

    Raises:
        IndexError: if the fragment has fewer than two tokens.
        ValueError: if the second token is not a valid float.
    """
    tokens = dict_str.split()
    value_token = tokens[1]
    return float(value_token)


def read_log(file_path, is_skip=False):
    """Collect one value from every ``USE_EMA: True`` iteration line of a log.

    Each line containing ``'iteration USE_EMA: True,'`` is split on commas,
    and the sixth field (index 5) is parsed with ``dict2value``.

    Args:
        file_path: path to the text log file.
        is_skip: if True, keep only every other matching line, starting with
            the first (1st, 3rd, 5th, ... matches are kept). NOTE(review):
            presumably the log writes each iteration twice; confirm.

    Returns:
        List of floats in file order. The name suggests these are
        unlabeled-data ratios; that is an assumption to verify against the
        log format.
    """
    ulb_ratios = []
    ema_skip = False
    # Use a context manager so the file handle is always closed. The original
    # opened the file and never closed it. Iterating the file object also
    # avoids loading every line into memory at once.
    with open(file_path, 'r') as log_file:
        for line in log_file:
            if 'iteration USE_EMA: True,' not in line:
                continue
            if is_skip:
                # Alternate between keeping and dropping successive matches.
                if ema_skip:
                    ema_skip = False
                    continue
                ema_skip = True
            blocks = line.split(',')
            ulb_ratios.append(dict2value(blocks[5]))
    return ulb_ratios


def read_main_log(file_path):
    """Parse two series of values from a main training log.

    Two kinds of (whitespace-stripped) lines are recognised:

    * Lines containing ``'Additional logging info:'``: the marker
      ``' Additional logging info:'`` is replaced by a comma, the line is
      split on commas, and field index 2 is parsed with ``dict2value`` and
      rounded to 4 decimals.
    * Otherwise, lines containing ``'iteration USE_EMA: True,'``: the line is
      split on commas and field index 7 is parsed with ``dict2value``.

    Args:
        file_path: path to the text log file.

    Returns:
        Tuple ``(pseudo_acces, ulb_ratios)`` of float lists in file order.
        The two lists come from different line types, so their lengths may
        differ.
    """
    pseudo_acces, ulb_ratios = [], []
    # A context manager guarantees the file is closed. The original leaked
    # the handle. Iterating lazily avoids materialising the whole file.
    with open(file_path, 'r') as log_file:
        for raw_line in log_file:
            line = raw_line.strip()
            if 'Additional logging info:' in line:
                blocks = line.replace(' Additional logging info:', ',').split(',')
                pseudo_acces.append(round(dict2value(blocks[2]), 4))
            elif 'iteration USE_EMA: True,' in line:
                blocks = line.split(',')
                ulb_ratios.append(dict2value(blocks[7]))
    return pseudo_acces, ulb_ratios


# ulb_ratio = read_log('ablation_aat_log/flexmatch_ic9600.txt')
def _segment_means(values, n_segments=5):
    """Return the mean of each of ``n_segments`` consecutive chunks of *values*.

    Every chunk has ``len(values) // n_segments`` elements, except the last,
    which also absorbs the remainder. The segment size is computed from
    *values* itself, so each series is split according to its own length.
    Empty chunks (when there are fewer values than segments) yield NaN
    without NumPy's empty-slice RuntimeWarning.
    """
    values = np.asarray(values, dtype=float)
    step = len(values) // n_segments
    bounds = [i * step for i in range(n_segments)] + [len(values)]
    return [
        values[lo:hi].mean() if hi > lo else np.nan
        for lo, hi in zip(bounds, bounds[1:])
    ]


# Print the per-segment means of both series returned by read_main_log,
# splitting each series into five consecutive segments. Each list is
# segmented by its own length; previously the pseudo-accuracy list was
# sliced using the ratio list's length, which is wrong whenever the two
# lists differ in size.
psuedo_acc, ulb_ratio = read_main_log('ablation_aat_log/main_ic9600.txt')
print(*_segment_means(ulb_ratio))
print(*_segment_means(psuedo_acc))

